# Copyright 2018 The trfl Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorFlow ops for the Retrace algorithm and continuous actions.

Safe and Efficient Off-Policy Reinforcement Learning
R. Munos, T. Stepleton, A. Harutyunyan, M. G. Bellemare
https://arxiv.org/abs/1606.02647

This variant is commonly used to update the Q-function in RS0, which
additionally uses SVG or an SVG variant to update the policy.

Learning by Playing - Solving Sparse Reward Tasks from Scratch
M. Riedmiller, R. Hafner, T. Lampe, M. Neunert, J. Degrave, T. Van de Wiele,
V. Mnih, N. Heess, J. T. Springenberg
https://arxiv.org/abs/1802.10567

Learning Continuous Control Policies by Stochastic Value Gradients
N. Heess, G. Wayne, D. Silver, T. Lillicrap, Y. Tassa, T. Erez
https://arxiv.org/abs/1510.09142
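
Summarising the construction implemented below (in the notation of the Retrace
paper above), the ops in this module compute targets of the form

  Q^ret(x_t, a_t) = Q(x_t, a_t)
      + sum_{s >= t} (gamma_t ... gamma_{s-1}) (c_{t+1} ... c_s) delta_s,
  delta_s = r_s + gamma_s * V(x_{s+1}) - Q(x_s, a_s),
  c_i = lambda * min(1, pi(a_i | x_i) / mu(a_i | x_i)),

where gamma_t are the per-step discounts, pi is the target policy and mu the
behaviour policy.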

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow.compat.v1 as tf


QTraceReturns = collections.namedtuple("QTraceReturns", [
    "qs", "importance_weights", "log_importance_weights",
    "truncated_importance_weights", "deltas", "vs_minus_q_xs"
])


def retrace_from_action_log_probs(
    behaviour_action_log_probs,
    target_action_log_probs,
    discounts,
    rewards,
    q_values,
    values,
    bootstrap_value,
    lambda_=1.,
    name="retrace_from_action_log_probs"):
  """Constructs Q/Retrace ops.

  This is an implementation of Retrace. In the description of the arguments
  the following notation is used: `T` is the sequence length over which the
  return is calculated, and `B` denotes the batch size.

  Args:
    behaviour_action_log_probs: Log-probabilities of the selected actions
      under the behaviour policy. Shape [T, B].
    target_action_log_probs: Log-probabilities of the same actions under the
      target policy. Shape [T, B].
    discounts: Also called pcontinues. Discount encountered when following
      the behaviour policy. Shape [T, B].
    rewards: A tensor containing rewards generated by following the behaviour
      policy. Shape [T, B].
    q_values: Q-function estimates wrt. the target policy. Shape [T, B].
    values: Value function estimates wrt. the target policy. Shape [T, B].
    bootstrap_value: Value function estimate at time `T`. Shape [B].
    lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1).
    name: The name scope that all qtrace ops will be created in.

  Returns:
    A `QTraceReturns` namedtuple containing:

        * qs: The Retrace regression/policy gradient targets.
          Can be used to calculate estimates of the advantage for policy
          gradients or as a regression target for Q-value functions.
          Shape [T, B].
        * importance_weights: Importance sampling weights. Shape [T, B].
        * log_importance_weights: Log importance sampling weights.
          Shape [T, B].
        * truncated_importance_weights: Called c_t in the paper. Shape [T, B].
        * deltas: Temporal differences r_t + discount_t * V(x_{t+1}) -
          Q(x_t, a_t). Shape [T, B].
        * vs_minus_q_xs: Q-Retrace targets minus Q(x_s, a_s). Shape [T, B].
  """
  # Turn arguments to tensors.
  behaviour_action_log_probs = tf.convert_to_tensor(
      behaviour_action_log_probs, dtype=tf.float32)
  target_action_log_probs = tf.convert_to_tensor(
      target_action_log_probs, dtype=tf.float32)
  values = tf.convert_to_tensor(values, dtype=tf.float32)
  q_values = tf.convert_to_tensor(q_values, dtype=tf.float32)
  bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
  discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
  rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)

  # Make sure tensor ranks are as expected.
  behaviour_action_log_probs.get_shape().assert_has_rank(2)
  target_action_log_probs.get_shape().assert_has_rank(2)
  values.get_shape().assert_has_rank(2)
  q_values.get_shape().assert_has_rank(2)
  bootstrap_value.get_shape().assert_has_rank(1)
  discounts.get_shape().assert_has_rank(2)
  rewards.get_shape().assert_has_rank(2)

  with tf.name_scope(
      name,
      values=[
          behaviour_action_log_probs, target_action_log_probs, discounts,
          rewards, q_values, values, bootstrap_value
      ]):
    log_rhos = target_action_log_probs - behaviour_action_log_probs
    return retrace_from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        q_values=q_values,
        values=values,
        bootstrap_value=bootstrap_value,
        lambda_=lambda_)
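

# The following is an illustrative usage sketch, not part of the library API.
# It assumes TF1 graph mode (e.g. `tf.disable_v2_behavior()` has been called)
# and uses hypothetical placeholder names; shapes follow the docstring above,
# with `T` the unroll length and `B` the batch size.
def _example_retrace_from_action_log_probs(sequence_length=5, batch_size=2):
  """Builds Retrace target ops from [T, B] placeholders (illustration only)."""
  shape = [sequence_length, batch_size]
  behaviour_logp = tf.placeholder(tf.float32, shape, name="behaviour_logp")
  target_logp = tf.placeholder(tf.float32, shape, name="target_logp")
  discounts = tf.placeholder(tf.float32, shape, name="discounts")
  rewards = tf.placeholder(tf.float32, shape, name="rewards")
  q_values = tf.placeholder(tf.float32, shape, name="q_values")
  values = tf.placeholder(tf.float32, shape, name="values")
  bootstrap_value = tf.placeholder(tf.float32, [batch_size], name="bootstrap")
  retrace = retrace_from_action_log_probs(
      behaviour_action_log_probs=behaviour_logp,
      target_action_log_probs=target_logp,
      discounts=discounts,
      rewards=rewards,
      q_values=q_values,
      values=values,
      bootstrap_value=bootstrap_value,
      lambda_=1.0)
  # retrace.qs ([T, B]) can serve as a regression target for a Q-network,
  # e.g. loss = tf.reduce_mean(tf.square(retrace.qs - predicted_q_values)).
  return retrace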


def retrace_from_importance_weights(log_rhos,
                                    discounts,
                                    rewards,
                                    q_values,
                                    values,
                                    bootstrap_value,
                                    lambda_=1.0,
                                    name="retrace_from_importance_weights"):
  """Constructs Q/Retrace ops.

  This is an implementation of Retrace. In the description of the arguments
  the following notation is used: `T` is the sequence length over which the
  return is calculated, and `B` denotes the batch size.

  Args:
    log_rhos: Log importance sampling weights of the target policy relative
      to the behaviour policy, i.e. log(pi(a|x) / mu(a|x)). Shape [T, B].
    discounts: Also called pcontinues. Discount encountered when following
      the behaviour policy. Shape [T, B].
    rewards: A tensor containing rewards generated by following the behaviour
      policy. Shape [T, B].
    q_values: Q-function estimates wrt. the target policy. Shape [T, B].
    values: Value function estimates wrt. the target policy. Shape [T, B].
    bootstrap_value: Value function estimate at time `T`. Shape [B].
    lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1).
    name: The name scope that all qtrace ops will be created in.

  Returns:
    A `QTraceReturns` namedtuple containing:

      * qs: The Retrace regression/policy gradient targets.
        Can be used to calculate estimates of the advantage for policy
        gradients or as a regression target for Q-value functions.
        Shape [T, B].
      * importance_weights: Importance sampling weights. Shape [T, B].
      * log_importance_weights: Log importance sampling weights.
        Shape [T, B].
      * truncated_importance_weights: Called c_t in the paper. Shape [T, B].
      * deltas: Temporal differences r_t + discount_t * V(x_{t+1}) -
        Q(x_t, a_t). Shape [T, B].
      * vs_minus_q_xs: Q-Retrace targets minus Q(x_s, a_s). Shape [T, B].

  Raises:
    ValueError: If the rank of any input tensor is inconsistent with the rank
      of `log_rhos`.
  """
  # Make sure tensor ranks are consistent.
  rho_rank = log_rhos.get_shape().ndims  # Usually 2.
  q_values.get_shape().assert_has_rank(rho_rank)
  values.get_shape().assert_has_rank(rho_rank)
  bootstrap_value.get_shape().assert_has_rank(rho_rank - 1)
  discounts.get_shape().assert_has_rank(rho_rank)
  rewards.get_shape().assert_has_rank(rho_rank)

  lambda_ = tf.convert_to_tensor(lambda_, dtype=tf.float32)

  with tf.name_scope(
      name, values=[log_rhos, discounts, rewards, values, bootstrap_value]):
    rhos = tf.exp(log_rhos)

    cs = tf.minimum(1.0, rhos, name="cs")

    # The recursion at step t uses c_{t+1}, so shift cs one step to the left
    # and pad the final step with c = 1 (it multiplies a zero accumulator).
    cs = tf.concat([cs[1:], tf.ones_like(cs[-1:])], axis=0)
    cs *= lambda_

    # Append the bootstrap value to obtain V(x_{t+1}) for t = 0, ..., T - 1.
    values_t_plus_1 = tf.concat(
        [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)

    # delta_t = (r_t + discount * V(x_{t+1}) - Q(x_t, a_t))
    deltas = (rewards + discounts * values_t_plus_1 - q_values)

    # All sequences are reversed since the computation starts from the back.
    sequences = (
        tf.reverse(discounts, axis=[0]),
        tf.reverse(cs, axis=[0]),
        tf.reverse(deltas, axis=[0]),
    )

    # The Retrace corrections (vs_minus_q_xs) are computed with a scan that
    # runs from the end of the trajectory back to its beginning.
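    # Explicitly, the scan implements the backward recursion
    #   vs_minus_q_xs[t] = deltas[t]
    #                      + discounts[t] * cs[t] * vs_minus_q_xs[t + 1],
    # with vs_minus_q_xs[T] = 0; cs has already been shifted so that cs[t]
    # holds the truncated importance weight of step t + 1.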
    def scanfunc(acc, sequence_item):
      discount_t, c_t, delta_t = sequence_item
      return delta_t + discount_t * c_t * acc

    initial_values = tf.zeros_like(bootstrap_value)
    vs_minus_q_xs = tf.scan(
        fn=scanfunc,
        elems=sequences,
        initializer=initial_values,
        parallel_iterations=1,
        back_prop=False,
        name="scan")
    # Reverse the results back to original order.
    vs_minus_q_xs = tf.reverse(vs_minus_q_xs, [0], name="vs_minus_q_xs")

    # Add Q(x_s, a_s) to recover the Retrace targets.
    qs = tf.add(vs_minus_q_xs, q_values, name="qs")

    result = QTraceReturns(
        qs=tf.stop_gradient(qs),
        importance_weights=tf.stop_gradient(rhos),
        log_importance_weights=tf.stop_gradient(log_rhos),
        truncated_importance_weights=tf.stop_gradient(cs),
        deltas=tf.stop_gradient(deltas),
        vs_minus_q_xs=tf.stop_gradient(vs_minus_q_xs))
    return result
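

# A minimal pure-Python sketch of the same backward recursion, operating on a
# single batch element represented as plain lists of floats. It is intended
# only as an illustration / spot-check of `retrace_from_importance_weights`
# and is not part of the library API.
def _reference_retrace_targets(log_rhos, discounts, rewards, q_values, values,
                               bootstrap_value, lambda_=1.0):
  """Returns the Retrace targets qs for one trajectory (illustration only)."""
  import math  # Local import to keep the sketch self-contained.
  num_steps = len(rewards)
  # Truncated importance weights, shifted so index t holds c_{t+1}; the final
  # entry multiplies a zero accumulator and is therefore irrelevant.
  cs = [lambda_ * min(1.0, math.exp(lr)) for lr in log_rhos[1:]] + [lambda_]
  values_t_plus_1 = list(values[1:]) + [bootstrap_value]
  deltas = [rewards[t] + discounts[t] * values_t_plus_1[t] - q_values[t]
            for t in range(num_steps)]
  acc = 0.0
  vs_minus_q_xs = [0.0] * num_steps
  for t in reversed(range(num_steps)):
    acc = deltas[t] + discounts[t] * cs[t] * acc
    vs_minus_q_xs[t] = acc
  return [vs_minus_q_xs[t] + q_values[t] for t in range(num_steps)]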
